{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e8cba0b6",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/quickstart.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>  \n",
    "\n",
    "## Quick Demo of Text2SQL Using Llama 3.3\n",
    "\n",
    "This demo shows how to use Llama 3.3 to answer questions about a SQLite DB. \n",
    "\n",
    "We'll use LangChain and the Llama cloud provider [Together.ai](https://api.together.ai/), where you can easily get a free API key (or you can use any other Llama cloud provider or even Ollama running Llama locally)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa4562d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain_together import ChatTogether\n",
    "\n",
    "os.environ['TOGETHER_API_KEY'] = 'your_api_key'\n",
    "\n",
    "llm = ChatTogether(\n",
    "    model=\"meta-llama/Llama-3.3-70B-Instruct-Turbo\",\n",
    "    temperature=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d421ae7",
   "metadata": {},
   "source": [
    "To recreate the `nba_roster.db` file, run the two commands below:\n",
    "- `python txt2csv.py` to convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n",
    "- `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56f0360e-fca3-49a8-9a70-0416f84e15fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# uncomment if you don't want to create the db yourself\n",
    "#! wget https://github.com/meta-llama/llama-recipes/raw/3649841b426999fdc61c30a9fc8721106bec769b/recipes/use_cases/coding/text2sql/nba_roster.db"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3bb99f39-cd7a-4db6-91dd-02f3bf80347c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# Note: to run in Colab, you need to upload the nba_roster.db file in the repo to the Colab folder first.\n",
    "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n",
    "\n",
    "def get_schema():\n",
    "    return db.get_table_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d793ce7-324b-4861-926c-54973d7c9b43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n",
      "\n",
      "Scheme:\n",
      "\n",
      "CREATE TABLE nba_roster (\n",
      "\t\"Team\" TEXT, \n",
      "\t\"NAME\" TEXT, \n",
      "\t\"Jersey\" TEXT, \n",
      "\t\"POS\" TEXT, \n",
      "\t\"AGE\" INTEGER, \n",
      "\t\"HT\" TEXT, \n",
      "\t\"WT\" TEXT, \n",
      "\t\"COLLEGE\" TEXT, \n",
      "\t\"SALARY\" TEXT\n",
      ")\n",
      "\n",
      "Question: What team is Stephen Curry on?\n",
      "\n",
      "SQL Query:\n"
     ]
    }
   ],
   "source": [
    "question = \"What team is Stephen Curry on?\"\n",
    "prompt = f\"\"\"Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n",
    "\n",
    "Scheme:\n",
    "{get_schema()}\n",
    "\n",
    "Question: {question}\n",
    "\n",
    "SQL Query:\"\"\"\n",
    "\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "70776558",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT Team FROM nba_roster WHERE NAME = 'Stephen Curry'\n"
     ]
    }
   ],
   "source": [
    "answer = llm.invoke(prompt).content\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afcf423a",
   "metadata": {},
   "source": [
    "***Note:*** If you don't have the \"just return the SQL query and nothing else\" in the prompt above, you'll likely get more text other than the SQL query back in the answer, making some extra post-processing necessary before running the db query below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "62472ce6-794b-4a61-b88c-a1e031e28e4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[('Golden State Warriors',)]\""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# note this is a dangerous operation and for demo purpose only; in production app you'll need to safe-guard any DB operation\n",
    "result = db.run(answer)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "39ed4bc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I don't have enough information to determine whose salary you are referring to. Could you please provide more context or specify the person you are asking about?\n"
     ]
    }
   ],
   "source": [
    "# how about a follow up question\n",
    "follow_up = \"What's his salary?\"\n",
    "print(llm.invoke(follow_up).content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b2c523",
   "metadata": {},
   "source": [
    "Since we did not pass any context along with the follow-up to Llama, it doesn't know the answer. Let's try to fix it by adding context to the follow-up prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n",
      "\n",
      "Scheme:\n",
      "\n",
      "CREATE TABLE nba_roster (\n",
      "\t\"Team\" TEXT, \n",
      "\t\"NAME\" TEXT, \n",
      "\t\"Jersey\" TEXT, \n",
      "\t\"POS\" TEXT, \n",
      "\t\"AGE\" INTEGER, \n",
      "\t\"HT\" TEXT, \n",
      "\t\"WT\" TEXT, \n",
      "\t\"COLLEGE\" TEXT, \n",
      "\t\"SALARY\" TEXT\n",
      ")\n",
      "\n",
      "Question: What's his salary?\n",
      "SQL Query: What team is Stephen Curry on?\n",
      "SQL Result: [('Golden State Warriors',)]\n",
      "\n",
      "New SQL Response:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "prompt = f\"\"\"Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n",
    "\n",
    "Scheme:\n",
    "{get_schema()}\n",
    "\n",
    "Question: {follow_up}\n",
    "SQL Query: {question}\n",
    "SQL Result: {result}\n",
    "\n",
    "New SQL Response:\n",
    "\"\"\"\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "03739b96-e607-4fa9-bc5c-df118198dc7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT SALARY FROM nba_roster WHERE NAME = \"Stephen Curry\"\n"
     ]
    }
   ],
   "source": [
    "new_answer = llm.invoke(prompt).content\n",
    "print(new_answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c782abb6-3b44-45be-8694-70fc29b82523",
   "metadata": {},
   "source": [
    "Because we have \"be concise, just output the SQL response\", Llama 3 is able to just generate the SQL statement; otherwise output parsing will be needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ecfca53-be7e-4668-bad1-5ca7571817d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[('$51,915,615',)]\""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db.run(new_answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
